{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import h2o4gpu\n",
    "import h2o4gpu.util.import_data as io\n",
    "import h2o4gpu.util.metrics as metrics\n",
    "from tabulate import tabulate\n",
    "from h2o4gpu.solvers.linear_regression import LinearRegression\n",
    "import pandas as pd\n",
    "import os, sys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reading Data with Pandas\n",
      "(999, 9733)\n",
      "Original m=999 n=9732\n",
      "Size of Train rows=800 & valid rows=199\n",
      "Size of Train cols=9732 valid cols=9732\n",
      "Size of Train cols=9733 & valid cols=9733 after adding intercept column\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Import Data for H2O GPU Edition\n",
    "\n",
    "This function will read in data and prepare it for H2O4GPU's GLM solver\n",
    "\n",
    "Parameters\n",
    "----------\n",
    "data_path : str\n",
    "             A path to a dataset (The dataset needs to be all numeric)\n",
    "use_pandas : bool\n",
    "              Indicate if Pandas should be used to parse\n",
    "intercept : bool\n",
    "              Indicate if intercept term is needed\n",
    "valid_fraction : float\n",
    "                  Percentage of dataset reserved for a validation set\n",
    "classification : bool\n",
    "                  Classification problem?\n",
    "Returns\n",
    "-------\n",
    "If valid_fraction > 0 it will return the following:\n",
    "    train_x: numpy array of train input variables\n",
    "    train_y: numpy array of y variable\n",
    "    valid_x: numpy array of valid input variables\n",
    "    valid_y: numpy array of valid y variable\n",
    "    family : string that would either be \"logistic\" if classification is set to True, otherwise \"elasticnet\"\n",
    "If valid_fraction == 0 it will return the following:\n",
    "    train_x: numpy array of train input variables\n",
    "    train_y: numpy array of y variable\n",
    "    family : string that would either be \"logistic\" if classification is set to True, otherwise \"elasticnet\"\n",
    "\"\"\"\n",
    "import os\n",
    "if not os.path.exists(\"ipums_1k.csv\"):\n",
    "    !wget https://s3.amazonaws.com/h2o-public-test-data/h2o4gpu/open_data/ipums_1k.csv\n",
    "train_x,train_y,valid_x,valid_y,family=io.import_data(data_path=\"ipums_1k.csv\", \n",
    "                                                        use_pandas=True, \n",
    "                                                        intercept=True,\n",
    "                                                        valid_fraction=0.2,\n",
    "                                                        classification=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Setting up solver\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Set up instance of H2O4GPU's GLM solver with default parameters\n",
    "\n",
    "Need to pass in `family` to indicate problem type to solve\n",
    "\"\"\"\n",
    "print(\"Setting up solver\")\n",
    "model = LinearRegression()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Solving\n",
      "CPU times: user 3.57 s, sys: 4.25 s, total: 7.82 s\n",
      "Wall time: 9.43 s\n",
      "Done Solving\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Fit Linear Regression Solver\n",
    "\"\"\"\n",
    "print(\"Solving\")\n",
    "%time model.fit(train_x, train_y)\n",
    "print(\"Done Solving\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predictions per alpha\n",
      "[[  5.89144754e+00   1.13267566e+03   6.29999883e+04   6.59062411e-04\n",
      "    1.84079254e-04   1.24079944e-03   5.32921113e-04   7.96878594e-05\n",
      "    1.60439056e-03   1.09333740e+04  -2.98141211e-04  -1.17179035e-04\n",
      "   -7.60835654e-04   1.70000098e+04  -1.99286174e-03   2.09998965e+04\n",
      "    4.80395215e+03   2.33241566e-03  -8.13281105e-04   1.05562675e+02\n",
      "   -1.16573608e+03  -8.37087282e-05   1.42230908e-03   5.89264679e+00\n",
      "   -8.03512870e-04   7.82845775e-04   1.78153267e+01   3.51782056e+03\n",
      "   -7.56430149e-04   2.17146997e-04   3.01949005e-03   1.68121606e+03\n",
      "    1.24978181e-03   6.65558044e+02   6.65558472e+02   7.39584412e+02\n",
      "   -1.86518743e-03  -1.08923845e-03  -9.00257321e-04  -3.39461910e-03\n",
      "   -9.70140973e-04  -2.09905044e-03   2.99995117e+03   2.99994995e+03\n",
      "    3.31795163e-04   1.76473870e-04  -9.76434851e-04   3.27940781e+04\n",
      "   -1.16139160e+03   1.70000020e+04  -1.00740243e-03  -5.56127401e-04\n",
      "    2.79090251e-03   9.76173340e+03  -8.03297968e-04  -8.05274758e-04\n",
      "   -4.33895206e+00  -4.34106302e+00   1.31012779e-03  -3.67362198e-04\n",
      "    1.53295137e+04  -5.08083263e-04   1.28081255e+01   1.28096237e+01\n",
      "   -1.25369593e-03  -5.00832975e-04   6.62199512e+03   2.28592684e-03\n",
      "    3.23815295e-03  -2.30143103e-03  -1.38945307e-03   3.35133667e+01\n",
      "    3.61817307e-04  -5.42415306e-04   2.28227480e+04  -6.01411331e-04\n",
      "   -7.20854441e-04  -1.29542570e-03   4.00000039e+04  -5.33683749e-04\n",
      "   -3.39027471e-03   1.20575994e-03   1.97378616e-03   1.27967435e-03\n",
      "   -1.02277088e-03   1.54535065e-03   4.33333301e+03  -1.07247732e-03\n",
      "    4.44212696e-03  -8.69292882e-04  -1.16139746e+03  -5.37149422e-03\n",
      "   -1.43741770e-03   1.05559372e+02   1.07938703e-03  -1.61416456e-03\n",
      "    9.04899905e-04   1.61994481e-03  -1.12700836e-04   1.64776978e+03\n",
      "    7.07465363e+00  -2.42081890e-03  -2.04855669e-03  -5.06862346e-03\n",
      "    6.19494629e+03   2.10829033e-03   2.63217656e+04  -2.35003047e-03\n",
      "    1.78778986e-03  -2.40198919e-03  -1.80959515e-03   3.19604343e-03\n",
      "    1.60988815e+02   1.39282248e-03  -9.29115234e+03   9.72605601e-04\n",
      "    2.95070605e-03   3.33333472e+03   1.15379086e-03   4.34018652e+03\n",
      "    2.14496670e+01   2.00000039e+04   7.40249329e+01   1.21761928e-03\n",
      "    7.07400751e+00   3.08897253e-03   2.54534454e+02  -1.02419220e-03\n",
      "   -1.87281694e-06   2.40197559e+03  -8.77125887e-04  -2.28053913e-03\n",
      "    3.45822395e-04   6.14803352e+01   2.20248662e-03  -8.58849147e-04\n",
      "    6.62199219e+03  -6.38691417e-04   1.38237793e-03  -3.25356983e-03\n",
      "    1.26637286e-03   1.14995352e+04   1.14995332e+04  -1.42054725e-03\n",
      "   -1.45975500e-04  -4.04204766e-04  -1.13091071e-03   1.24716840e-04\n",
      "    2.50666582e+04   3.27940781e+04   3.98529346e+03   2.43660616e-05\n",
      "    6.85144806e+00   5.89575815e+00   5.89360332e+00   2.79408414e-03\n",
      "   -2.71717296e-03   3.30178021e-03   2.14033891e-04   2.04296643e-03\n",
      "    7.55258254e-04   7.22505312e+04   1.84186152e-03   3.83563398e-04\n",
      "    1.40931597e-03   1.88023412e+00   1.87250030e+00   1.94328779e-03\n",
      "    1.16607901e-02   3.60594015e-03   1.17056631e-03  -1.62261806e-03\n",
      "   -1.75390567e-03   1.99847925e-03   3.92020447e-04  -6.41307514e-03\n",
      "    8.78024264e-04   1.22488418e+04  -5.13260020e-03   1.98827637e-03\n",
      "    9.00543586e-04  -1.09907950e-03   2.40197583e+03  -1.50204939e-03\n",
      "    9.92256333e-04   2.40197485e+03   7.73297576e-03  -2.87989690e-03\n",
      "    2.85557588e-04   1.13267725e+03   3.10035120e-03   8.91121454e-04\n",
      "    4.71562892e-03   2.83576758e-03   3.73312924e-03   5.27994141e+03\n",
      "    4.69563529e-04   4.40287607e-04   2.39236397e-03]]\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Make predictions on validation set\n",
    "\"\"\"\n",
    "print(\"Predictions per alpha\")\n",
    "preds = model.predict(valid_x)\n",
    "print(preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "del model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.6.1(pyenv)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
